package org.funny.nn.som.old;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * 自组织映射(Self-organizing map, SOM) 本例用于解决旅行商问题。地图上有N个城市。一个商人要走一圈经过所有的城市，并回到开始的城市。
 * 希望走过的路线尽可能短。
 *
 * GIT HUB上搜索 SOM只找到Python的内容。这里转化成JAVA。
 * https://github.com/diego-vicente/som-tsp/
 * 运行完毕后，可以在diagrams 文件夹下查看png文件，体验一下，神经(绿色点)如何一步步，逼近目标。(红色为城市，神经点最后会贴近旅行商的路线)
 *
 * @author: LinLW
 */
public class Main {
    // 初始学习率
    private static final double INIT_LEARN_RATE=0.618;
    // 衰减率
    private static final double DECAYED_RATE =0.997;

    private static final int ITERATIONS =1000;

    public static List<City> readCities2(){
        List<City> cities=new ArrayList<>();
        //固定提供一些基础数据即可
        cities.add(new City("福州",119.28, 26.08));
        cities.add(new City("泉州",118.67, 24.88));
        cities.add(new City("厦门",118.09, 24.48));
        cities.add(new City("漳州",117.72, 24.52));
        cities.add(new City("宁德",119.53, 26.66));
        cities.add(new City("龙岩",117.02, 25.08));
        cities.add(new City("三明",117.36, 26.13));
        cities.add(new City("莆田",119.00, 25.43));
        cities.add(new City("南平",118.12, 27.33));

        return cities;
    }
    public static List<City> readCities(){
        List<City> cities=new ArrayList<>();
        //固定提供一些基础数据即可
        cities.add(new City("沈阳市",123.429092,41.796768));
        cities.add(new City("长春市",125.324501,43.886841));
        cities.add(new City("哈尔滨市",126.642464,45.756966));
        cities.add(new City("北京市",116.405289,39.904987));
        cities.add(new City("天津市",117.190186,39.125595));
        cities.add(new City("呼和浩特市",111.751990,40.841490));
        cities.add(new City("银川市",106.232480,38.486440));
        cities.add(new City("太原市",112.549248,37.857014));
        cities.add(new City("石家庄市",114.502464,38.045475));
        cities.add(new City("济南市",117.000923,36.675808));
        cities.add(new City("郑州市",113.665413,34.757977));
        cities.add(new City("西安市",108.948021,34.263161));
        cities.add(new City("武汉市",114.298569,30.584354));
        cities.add(new City("南京市",118.76741,32.041546));
        cities.add(new City("合肥市",117.283043,31.861191));
        cities.add(new City("上海市",121.472641,31.231707));
        cities.add(new City("长沙市",112.982277,28.19409));
        cities.add(new City("南昌市",115.892151,28.676493));
        cities.add(new City("杭州市",120.15358,30.287458));
        cities.add(new City("福州市",119.306236,26.075302));
        cities.add(new City("广州市",113.28064,23.125177));
        cities.add(new City("台北市",121.5200760,25.0307240));
        cities.add(new City("海口市",110.199890,20.044220));
        cities.add(new City("南宁市",108.320007,22.82402));
        cities.add(new City("重庆市",106.504959,29.533155));
        cities.add(new City("昆明市",102.71225,25.040609));
        cities.add(new City("贵阳市",106.713478,26.578342));
        cities.add(new City("成都市",104.065735,30.659462));
        cities.add(new City("兰州市",103.834170,36.061380));
        cities.add(new City("西宁市",101.777820,36.617290));
        cities.add(new City("拉萨市",91.11450,29.644150));
        cities.add(new City("乌鲁木齐市",87.616880,43.826630));
        cities.add(new City("香港",114.165460,22.275340));
        cities.add(new City("澳门",113.549130,22.198750));


        return cities;
    }



    public static void main(String[] args){
        List<City> cities= readCities();
        normalize(cities);// 归一化
        System.out.println("归一化后");
        for(City city:cities){
            System.out.println(city.getCity()+" x="+city.getX()+", y="+city.getY());
        }

        som(cities, ITERATIONS);
    }
    public static void som(List<City>problem,int iterations){
        som(problem,iterations,INIT_LEARN_RATE);
    }
    /*
    * Solve the TSP using a Self-Organizing Map.
    * @param: [cities, iterations, learnRate]
    * @return: int[]
    */
    public static void som(List<City>cities,int iterations,double learnRate){

        // The population size is 8 times the number of cities
        // 根据另一个文章，这里最佳值是，sqrt(5*sqrt(n));
        // 但经过测试，如果用5经常会出现奇怪的解答
        int n = cities.size() * 8;

        // 产生一个神经网络
       List<Point> network = Neuron.generateNetwork(n);
        double radix=n/10.0;

        boolean hasBreak=false;
        for(int i=0; i<iterations; i++){
            if (i==0||i==1||i==2||i==3||i==5||i==7||
                    i==10||i==20||i==30||i==50||i==70||
                    i==100||i==200||i==300||i==500||i==700||
                    i==1000||i==2000||i==3000||i==5000||i==7000) {//画图
                //plotNetwork(cities, network, name = 'diagrams/{:05d}.png'.format(i))
                System.out.println("\t> 第 " + i + "/" + iterations+" 次循环");
                Plot.plotNetwork(cities, network, "diagrams/round"+i+".png");
            }
            Collections.shuffle(cities);// 随机打乱城市的顺序。防止排在前面的城市学习效果偏高。
            for(City city:cities) {

                int winner_idx = Distance.selectClosest(network, city);

                // Generate a filter that applies changes to the winner's gaussian
                double[] gaussian = Neuron.getNeighborhood(winner_idx, radix, network.size());

                // Update the network's weights (closer to the city)
                int tempIndex = 0;
                for (Point row : network) {
                    row.setX(row.getX()+ gaussian[tempIndex] * learnRate * (city.getX() - row.getX()));
                    row.setY(row.getY()+ gaussian[tempIndex] * learnRate * (city.getY() - row.getY()));
                    tempIndex++;
                }
            }
            if (radix < 1) {
                System.out.println("半径已经过于衰减, 停止执行。当前在第 " + i + " 轮循环");
                hasBreak = true;
                break;
            }
            if (learnRate < 0.01) {
                System.out.println("学习率已经过于衰减, 停止执行。当前在第 " + i + " 轮循环");
                hasBreak = true;
                break;
            }
            learnRate *= DECAYED_RATE;
            radix *= DECAYED_RATE;
        }
        //没有break的话
        if(!hasBreak) {
            System.out.println("完成 " + iterations + " 次循环");
        }
        Plot.plotNetwork(cities, network, "diagrams/round"+iterations+".png");

        cities =  Neuron.getRoute(cities, network);
        for(City c:cities){
            System.out.println(c.getCity());
        }
    }

    /**
     * 归一化，把城市的x和y 都换算成0-1之间数。最西边的X为0 最东边的X为1 同理最北边的Y为1最那边的Y为0
     * @param points 所有城市列表
     */
    public static void normalize(List<City> points){

        double maxX=Float.MIN_VALUE, maxY=Float.MIN_VALUE;
        double minX=Float.MAX_VALUE, minY=Float.MAX_VALUE;
        for(Point p:points){
            maxX=p.getX()>maxX?p.getX():maxX;
            minX=p.getX()<minX?p.getX():minX;
            maxY=p.getY()>maxY?p.getY():maxY;
            minY=p.getY()<minY?p.getY():minY;
        }
        for(Point p :points){
            p.setX((p.getX()-minX)/ (maxX-minX));
            p.setY((p.getY()-minY)/ (maxY-minY));
        }
    }
}
